/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "infer_shape_mat_mul.h"
#include <algorithm>
#include "op_desc.h"
#include "attr_utils.h"
#include "operator.h"
using namespace ge;
namespace ops {
namespace {
// Dim value marking a single dynamic (unknown) dimension.
const int64_t UNKNOWN_DIM{-1};
// Internal stand-in for an unbounded range. Using INT64_MAX as the upper bound
// keeps the std::min / std::max based range intersections in this file sane.
const std::pair<int64_t, int64_t> NORMALIZE_FULL_RANGE{1, std::numeric_limits<int64_t>::max()};
// A shape consisting of exactly this single dim has an unknown rank.
const int64_t VALUE_UNKNOWN_RANK{-2};
// External encoding of an unbounded range (upper bound -1). It is mapped to
// NORMALIZE_FULL_RANGE on input and mapped back on output.
const std::pair<int64_t, int64_t> FULL_RANGE{1, -1};
// All-zero range; for unknown dims NormalizeRange treats it like FULL_RANGE.
const std::pair<int64_t, int64_t> EMPTY_RANGE{0, 0};

// Returns true when `dim` may take part in inference. Positive static dims,
// UNKNOWN_DIM (-1) and VALUE_UNKNOWN_RANK (-2) are accepted. 0 (which is also
// the value left in padded positions of the right-aligned buffers) and anything
// below -2 are rejected.
bool IsDimValid(int64_t dim) {
  if (dim == 0) {
    return false;
  }
  return !(dim < VALUE_UNKNOWN_RANK);
}
// True when `shape` uses the unknown-rank encoding: exactly one dim, and that
// dim equals VALUE_UNKNOWN_RANK.
bool IsUnknownRank(const ge::Shape &shape) {
  if (shape.GetDimNum() != 1) {
    return false;
  }
  return shape.GetDim(0) == VALUE_UNKNOWN_RANK;
}
bool IsUnknownShape(const Shape &shape) {
  if (shape.GetDimNum() == 0) {
    return false;
  }

  for (size_t i = 0; i < shape.GetDimNum(); ++i) {
    if (shape.GetDim(i) == UNKNOWN_DIM) {
      return true;
    }
  }

  return false;
}

// Combines two dims (with their normalised ranges) that must describe the same
// extent, e.g. K of A vs K of B, or N of B vs the bias.
// On success, writes the combined dim to `dim` and its range to `range`, and
// returns true. Returns false when the dims are provably incompatible.
// `op_name` is currently unused.
bool IntersectDimensionAndRange(const std::string &op_name,
                                const int64_t dim_a,
                                const int64_t dim_b,
                                const std::pair<int64_t, int64_t> &range_a,
                                const std::pair<int64_t, int64_t> &range_b,
                                int64_t &dim,
                                std::pair<int64_t, int64_t> &range) {
  // | b\a        | -1,(y1,y2)                       | y          |
  // | ---------- | -------------------------------- | ---------- |
  // | -1,(x1,x2) | -1,(max(x1,y1),min(x2,y2)) check | y check    |
  // | x          | x check                          | x==y check |

  // Both static: they must agree exactly.
  if (dim_a > 0 && dim_b > 0) {
    if (dim_a != dim_b || range_a != range_b) {
      return false;
    }
    dim = dim_a;
    range = range_a;
    return true;
  }

  // Both dynamic: the result is dynamic over the overlap of the two ranges.
  if (dim_a == UNKNOWN_DIM && dim_b == UNKNOWN_DIM) {
    const auto lower_bound = std::max(range_a.first, range_b.first);
    const auto upper_bound = std::min(range_a.second, range_b.second);
    if (lower_bound > upper_bound) {
      return false;
    }
    // Previously `dim` was left untouched here; report the dynamic result explicitly.
    dim = UNKNOWN_DIM;
    range = {lower_bound, upper_bound};
    return true;
  }

  // A dynamic, B static: B must lie inside A's range.
  if (dim_a == UNKNOWN_DIM) {
    if (range_a.first <= dim_b && dim_b <= range_a.second) {
      dim = dim_b;
      range = range_b;
      return true;
    }
    return false;
  }
  // A static (B dynamic): A must lie inside B's range.
  if (range_b.first <= dim_a && dim_a <= range_b.second) {
    dim = dim_a;
    range = range_a;
    return true;
  }
  return false;
}

// Broadcasts one batch dim of A against the matching dim of B (or a previous
// result against the bias). On success, writes the result to `dim` / `range`
// and returns true. Returns false when the two dims cannot be broadcast.
// `op_name` is currently unused.
bool BroadcastDimensionAndRange(const char *op_name,
                                const int64_t dim_a,
                                const int64_t dim_b,
                                const std::pair<int64_t, int64_t> &range_a,
                                const std::pair<int64_t, int64_t> &range_b,
                                int64_t &dim,
                                std::pair<int64_t, int64_t> &range) {
  // | b\a        | -1,(1,y)        | -1,(y1,y2)                       | 0          | 1          | y          |
  // | ---------- | --------------- | -------------------------------- | ---------- | ---------- | ---------- |
  // | -1,(1,x)   | -1,(1,max(x,y)) | -1,(y1,y2)                       | -1,(1,x)   | -1,(1,x)   | y check    |
  // | -1,(x1,x2) | -1,(x1,x2)      | -1,(max(x1,y1),min(x2,y2)) check | -1,(x1,x2) | -1,(x1,x2) | y check    |
  // | 0          | -1,(1,y)        | -1,(y1,y2)                       | 0          | 1          | y          |
  // | 1          | -1,(1,y)        | -1,(y1,y2)                       | 1          | 1          | y          |
  // | x          | x check         | x check                          | x          | x          | x==y check |

  // Stores one result in the out-parameters. `range` may alias `range_a`
  // (InferBatch passes the same object for both), so each result is fully
  // built before it is assigned.
  auto emit = [&dim, &range](int64_t out_dim, const std::pair<int64_t, int64_t> &out_range) -> bool {
    dim = out_dim;
    range = out_range;
    return true;
  };

  // A 0 on either side defers to the other operand. Zeros are checked before ones.
  if (dim_a == 0) return emit(dim_b, range_b);
  if (dim_b == 0) return emit(dim_a, range_a);
  // A 1 broadcasts against anything.
  if (dim_a == 1) return emit(dim_b, range_b);
  if (dim_b == 1) return emit(dim_a, range_a);

  // At least one static dim greater than 1: it must match, or fit the other range.
  if (dim_a > 1 && dim_b > 1) {
    return dim_a == dim_b && emit(dim_a, range_a);
  }
  if (dim_a > 1) {
    return range_b.first <= dim_a && dim_a <= range_b.second && emit(dim_a, range_a);
  }
  if (dim_b > 1) {
    return range_a.first <= dim_b && dim_b <= range_a.second && emit(dim_b, range_b);
  }

  // Both dynamic from here on.
  if (range_a.first == 1 && range_b.first == 1) {
    return emit(UNKNOWN_DIM, {1, std::max(range_a.second, range_b.second)});
  }
  if (range_a.first > 1 && range_b.first > 1) {
    const int64_t lower = std::max(range_a.first, range_b.first);
    const int64_t upper = std::min(range_a.second, range_b.second);
    if (lower > upper) {
      return false;
    }
    return emit(UNKNOWN_DIM, {lower, upper});
  }
  return range_a.first > 1 ? emit(dim_a, range_a) : emit(dim_b, range_b);
}
// Canonicalises one dim's range for the intersection / broadcast helpers:
//   * static dim (> 0)                            -> {dim, dim}
//   * UNKNOWN_DIM with EMPTY_RANGE or FULL_RANGE  -> NORMALIZE_FULL_RANGE
//   * anything else                               -> shape_range unchanged
// `range` may alias `shape_range`; NormalizeShapeAndRange passes the same
// element for both, and every branch here is safe under that aliasing.
// `op_name` is currently unused.
// NOTE(review): an unknown dim whose range has an open upper bound (-1) but a
// lower bound other than 1 (e.g. {2, -1}) is passed through unchanged. Later
// std::min-based intersections therefore see -1 as its upper bound. Confirm
// whether such ranges can reach this code.
void NormalizeRange(const std::string &op_name, const int64_t dim,
                    const std::pair<int64_t, int64_t> &shape_range,
                    std::pair<int64_t, int64_t> &range) {
  if (dim > 0) {
    range = {dim, dim};
  } else if (dim == UNKNOWN_DIM && (shape_range == EMPTY_RANGE || shape_range == FULL_RANGE)) {
    range = NORMALIZE_FULL_RANGE;
  } else {
    range = shape_range;
  }
}
}

// Minimum working rank: the two trailing matrix dims every operand is aligned to.
const int64_t InferShapeMatMul::base_len{2};

// True when none of A, B or the bias is dynamic, as judged by IsUnknownShape.
bool InferShapeMatMul::IsStaticShape() {
  const bool any_dynamic = IsUnknownShape(shape_a) || IsUnknownShape(shape_b) || IsUnknownShape(shape_bias);
  return !any_dynamic;
}

// Binds the operands and outputs, sizes the working buffers, and pre-sizes the outputs.
//
// The working rank `num_dim` is the largest operand rank, clamped below by
// base_len. shape_out / range_out are resized to that rank. When every operand
// is static, the input ranges range_a / range_b / range_bias are cleared.
// Note that these are the caller's objects, taken by non-const reference.
InferShapeMatMul::InferShapeMatMul(const char *op_name,
                                   const ge::Shape &shape_a, const ge::Shape &shape_b, const ge::Shape &shape_bias,
                                   ge::Range &range_a, ge::Range &range_b, ge::Range &range_bias,
                                   bool trans_a, bool trans_b,
                                   ge::Shape &shape_out, ge::Range &range_out,
                                   bool has_batch)
    : op_name(op_name), shape_a(shape_a), shape_b(shape_b), shape_bias(shape_bias),
      range_a(range_a), range_b(range_b), range_bias(range_bias),
      trans_a(trans_a), trans_b(trans_b),
      shape_out(shape_out), range_out(range_out), has_batch(has_batch) {
  // Widest operand, but never fewer than the two matrix dims.
  num_dim = std::max({shape_a.GetDimNum(), shape_b.GetDimNum(), shape_bias.GetDimNum()});
  if (num_dim < base_len) {
    num_dim = base_len;
  }

  // Zero-filled working copies; NormalizeShapeAndRange right-aligns each operand into them.
  infer_shape_a.assign(num_dim, 0);
  infer_range_a.assign(num_dim, std::pair<int64_t, int64_t>(0, 0));

  infer_shape_b.assign(num_dim, 0);
  infer_range_b.assign(num_dim, std::pair<int64_t, int64_t>(0, 0));

  // Bias buffers exist only when a bias shape was supplied.
  const bool has_bias = shape_bias.GetDimNum() != 0;
  if (has_bias) {
    infer_shape_bias.assign(num_dim, 0);
    infer_range_bias.assign(num_dim, std::pair<int64_t, int64_t>(0, 0));
  }

  shape_out.SetDimNum(num_dim);
  range_out.SetDimNum(num_dim);

  if (!IsStaticShape()) {
    return;
  }
  range_a.SetDimNum(0);
  range_b.SetDimNum(0);
  range_bias.SetDimNum(0);
}

void InferShapeMatMul::NormalizeShapeAndRange() {
  if (IsUnknownRank(shape_a)) {
    for (int i = num_dim - base_len; i < num_dim; ++i) {
      infer_shape_a[i] = UNKNOWN_DIM;
      infer_range_a[i] = NORMALIZE_FULL_RANGE;
    }
  } else {
    for (size_t i = 0; i < shape_a.GetDimNum(); ++i) {
      infer_shape_a[num_dim + i - shape_a.GetDimNum()] = shape_a.GetDim(i);
      infer_range_a[num_dim + i - shape_a.GetDimNum()] = range_a.GetDimRange(i);
    }
  }

  if (IsUnknownRank(shape_b)) {
    for (int i = num_dim - base_len; i < num_dim; ++i) {
      infer_shape_b[i] = UNKNOWN_DIM;
      infer_range_b[i] = NORMALIZE_FULL_RANGE;
    }
  } else {
    for (size_t i = 0; i < shape_b.GetDimNum(); ++i) {
      infer_shape_b[num_dim + i - shape_b.GetDimNum()] = shape_b.GetDim(i);
      infer_range_b[num_dim + i - shape_b.GetDimNum()] = range_b.GetDimRange(i);
    }
  }

  if (shape_bias.GetDimNum() != 0) {
    if (IsUnknownRank(shape_bias)) {
      infer_shape_bias[num_dim - 1] = UNKNOWN_DIM;
      infer_range_bias[num_dim - 1] = NORMALIZE_FULL_RANGE;
    } else {
      for (size_t i = 0; i < shape_bias.GetDimNum(); ++i) {
        infer_shape_bias[num_dim + i - shape_bias.GetDimNum()] = shape_bias.GetDim(i);
        infer_range_bias[num_dim + i - shape_bias.GetDimNum()] = range_bias.GetDimRange(i);
      }
    }
  }

  for (auto i = num_dim - shape_a.GetDimNum(); i < num_dim; ++i) {
    NormalizeRange(op_name, infer_shape_a[i], infer_range_a[i], infer_range_a[i]);
  }

  for (auto i = num_dim - shape_b.GetDimNum(); i < num_dim; ++i) {
    NormalizeRange(op_name, infer_shape_b[i], infer_range_b[i], infer_range_b[i]);
  }

  for (auto i = num_dim - shape_bias.GetDimNum(); i < num_dim; ++i) {
    NormalizeRange(op_name, infer_shape_bias[i], infer_range_bias[i], infer_range_bias[i]);
  }
}

bool InferShapeMatMul::InferMKN() {
  int64_t idx_m = trans_a ? num_dim - 1 : num_dim - 2;
  int64_t idx_k_a = trans_a ? num_dim - 2 : num_dim - 1;
  int64_t idx_k_b = trans_b ? num_dim - 1 : num_dim - 2;
  int64_t idx_n_b = trans_b ? num_dim - 2 : num_dim - 1;

  auto m = infer_shape_a[idx_m];
  auto k_a = infer_shape_a[idx_k_a];
  auto k_b = infer_shape_b[idx_k_b];
  auto n_b = infer_shape_b[idx_n_b];
  auto n = n_b;

  int64_t k;
  if (!IsDimValid(m) || !IsDimValid(k_a) || !IsDimValid(k_b) || !IsDimValid(n_b)) {
    return false;
  }

  std::pair<int64_t, int64_t> range_k, range_n = infer_range_b[idx_n_b];
  if (k_a > 0 && k_b > 0 && k_a != k_b) {
    return false;
  } else if (k_a < 0 && k_b < 0) {
    if (!IntersectDimensionAndRange(op_name, k_a, k_b, infer_range_a[idx_k_a], infer_range_b[idx_k_b], k, range_k)) {
      return false;
    }
  }
  if (shape_bias.GetDimNum() != 0) {
    int64_t idx_n_bias = num_dim - 1;
    int64_t n_bias = infer_shape_bias[idx_n_bias];
    if (!IsDimValid(n_bias)) {
      return false;
    }

    if (!IntersectDimensionAndRange(op_name, n_b, n_bias, infer_range_b[idx_n_b], infer_range_bias[idx_n_bias], n,
                                    range_n)) {
      return false;
    }
  }

  shape_out.SetDim(num_dim - 2, m);
  shape_out.SetDim(num_dim - 1, n);
  range_out.SetDimRange(num_dim - 2, infer_range_a[idx_m] == NORMALIZE_FULL_RANGE ? FULL_RANGE : infer_range_a[idx_m]);
  range_out.SetDimRange(num_dim - 1, range_n == NORMALIZE_FULL_RANGE ? FULL_RANGE : range_n);

  return true;
}

bool InferShapeMatMul::InferBatch() {
  for (auto i = 0; i < num_dim - 2; ++i) {
    if (!BroadcastDimensionAndRange(op_name, infer_shape_a[i], infer_shape_b[i], infer_range_a[i], infer_range_b[i],
                                    shape_out.GetDim(i), range_out.GetDimRange(i))) {
      return false;
    }

    if (shape_bias.GetDimNum() != 0) {
      if (!BroadcastDimensionAndRange(op_name, shape_out.GetDim(i), infer_shape_bias[i], range_out.GetDimRange(i), infer_range_bias[i],
                                      shape_out.GetDim(i), range_out.GetDimRange(i))) {
        return false;
      }
    }

    if (range_out.GetDimRange(i) == NORMALIZE_FULL_RANGE) {
      range_out.SetDimRange(i, FULL_RANGE);
    }
  }
  return true;
}

// Folds every output dim whose range has collapsed to a single value {v, v}
// back into the static dim v. The range itself is left as is.
void InferShapeMatMul::SimplifyShapeAndRange() {
  for (int idx = 0; idx < range_out.GetDimsCount(); ++idx) {
    const std::pair<int64_t, int64_t> &bounds = range_out.GetDimRange(idx);
    if (bounds.first != bounds.second) {
      continue;
    }
    shape_out.SetDim(idx, bounds.first);
  }
}

bool InferShapeMatMul::GetShapeRangeOfOutput() {
  if (!has_batch && IsUnknownRank(shape_a) && IsUnknownRank(shape_b) &&
      (shape_bias.GetDimNum() == 0 || IsUnknownRank(shape_bias))) {
    shape_out.SetDimNum(1);
    shape_out.SetDim(0, VALUE_UNKNOWN_RANK);
    range_out.SetDimNum(0);
    return true;
  }
  if (has_batch && (IsUnknownRank(shape_a) || IsUnknownRank(shape_b) || IsUnknownRank(shape_bias))) {
    shape_out.SetDimNum(1);
    shape_out.SetDim(0, VALUE_UNKNOWN_RANK);
    range_out.SetDimNum(0);
    return true;
  }

  NormalizeShapeAndRange();

  if (!InferMKN()) {
    return false;
  }

  if (!InferBatch()) {
    return false;
  }

  SimplifyShapeAndRange();
  return true;
}

// Infers the MatMul output shape and range from `op_desc`.
//
// Inputs: A is input 0, B is input 1, and an optional bias is input 2.
// trans_a / trans_b are read from bool attributes 0 and 1 (by index).
// Results are written to shape_out / shape_range_out.
// NOTE: `name_attr` is currently unused.
// NOTE: the shape ranges of inputs 0 and 1 are passed by mutable reference.
// InferShapeMatMul clears them when all operands are static.
// Returns FAILED on a null descriptor, a missing attribute, or incompatible shapes.
Status GetMatMulOutputShape(OpDescPtr op_desc,
                            ge::Shape &shape_out,
                            ge::Range &shape_range_out,
                            const std::string &name_attr, bool has_batch) {
  if (op_desc == nullptr) {
    return FAILED;
  }
  auto desc_a = op_desc->MutableInputDesc(0);
  auto desc_b = op_desc->MutableInputDesc(1);
  // Both operands are mandatory; bail out instead of dereferencing null.
  if (desc_a == nullptr || desc_b == nullptr) {
    return FAILED;
  }
  const ge::Shape &shape_a = desc_a->GetShape();
  const ge::Shape &shape_b = desc_b->GetShape();
  ge::Range &shape_range_a = desc_a->MutableShapeRange();
  ge::Range &shape_range_b = desc_b->MutableShapeRange();

  // Optional bias. When absent, the default-constructed Shape is used, and the
  // inference treats GetDimNum() == 0 as "no bias".
  Shape shape_bias;
  Range shape_range_bias;
  if (op_desc->ValidInputIndex(2)) {
    shape_bias = op_desc->GetInputDesc(2).GetShape();
    shape_range_bias = op_desc->GetInputDesc(2).GetShapeRange();
  }

  bool trans_a = false;
  if (!AttrUtils::GetBoolById(*op_desc, 0, trans_a)) {
    return FAILED;
  }

  bool trans_b = false;
  if (!AttrUtils::GetBoolById(*op_desc, 1, trans_b)) {
    return FAILED;
  }

  InferShapeMatMul obj(op_desc->GetName(), shape_a, shape_b, shape_bias, shape_range_a, shape_range_b,
                       shape_range_bias, trans_a, trans_b, shape_out, shape_range_out, has_batch);
  if (!obj.GetShapeRangeOfOutput()) {
    return FAILED;
  }

  return SUCCESS;
}

#define IMPLEMT_COMMON_INFERFUNC(F) Status F(const Operator &op)

// Shape-inference entry point for MatMul (no batch dims).
// Sets output 0's shape range, shape and origin shape from GetMatMulOutputShape,
// and copies the data type from input 0.
// Returns FAILED if the node, op descriptor or output descriptor is missing,
// or if the shape inference fails.
IMPLEMT_COMMON_INFERFUNC(MatMulInferShape) {
  auto node = op.GetNode();
  if (node == nullptr) {
    return FAILED;
  }
  auto op_desc = node->GetOpDesc();
  if (op_desc == nullptr) {
    return FAILED;
  }
  auto tensordesc_output = op_desc->MutableOutputDesc(0);
  if (tensordesc_output == nullptr) {
    return FAILED;
  }
  auto &tensordesc_x1 = op_desc->GetInputDesc(0);

  ge::Shape shape_out;
  ge::Range shape_range_out;
  if (SUCCESS != GetMatMulOutputShape(op_desc, shape_out, shape_range_out, "transpose", false)) {
    return FAILED;
  }

  tensordesc_output->SetShapeRange(shape_range_out);
  tensordesc_output->SetShape(shape_out);
  tensordesc_output->SetOriginShape(shape_out);
  tensordesc_output->SetDataType(tensordesc_x1.GetDataType());
  return SUCCESS;
}
}